import torch as T
import numpy as np
import tqdm
import pickle
import gzip
import argparse

# Command-line interface for the hypergrid GFlowNet experiment.
# Numeric defaults are given as real numbers (not strings) so that every
# option follows the same convention; the parsed values are unchanged.
parser = argparse.ArgumentParser()

parser.add_argument("--save_path", default='results/tb_r0.pkl.gz', type=str,
                    help="Path of the gzip-compressed pickle written after training")
parser.add_argument("--device", default='cuda', type=str,
                    help="torch device used for the model and the rollouts")
parser.add_argument("--seed", default=0, type=int,
                    help="Seed passed to the torch and numpy random generators")

# GFlowNet
# NOTE(review): --method is not read by the code visible in this file.
parser.add_argument("--method", default='flownet', type=str)
parser.add_argument("--loss", default='trajectory_balance', type=str,
                    help="Training objective: pass 'detailed balance' for the "
                         "detailed balance loss; the default "
                         "'trajectory_balance' gives the trajectory balance loss")
parser.add_argument("--pb_type", default='flexible', type=str, help="flexible or uniform")
parser.add_argument("--learning_rate_model", default=1e-3, help="Learning rate for model parameters", type=float)
parser.add_argument("--learning_rate_z", default=1e-1, help="Learning rate for Z", type=float)
parser.add_argument("--mbsize", default=16, help="Minibatch size", type=int)
parser.add_argument("--n_hid", default=256, type=int,
                    help="Width of each hidden layer")
parser.add_argument("--n_layers", default=2, type=int,
                    help="Number of hidden layers")
parser.add_argument("--n_train_steps", default=62500, type=int,
                    help="Number of training iterations (one minibatch each)")
parser.add_argument("--num_empirical_loss", default=200000, type=int,
                    help="Number of samples used to compute the empirical distribution loss")
parser.add_argument('--exp_weight', default=0.0, type=float,
                    help="Weight of the uniform exploration distribution mixed "
                         "into the sampling policy")

# Env
parser.add_argument("--horizon", default=8, type=int,
                    help="Grid side length; coordinates range over 0..horizon-1")
parser.add_argument('--r', default=1e-3, type=float,
                    help="Base reward added to every state")
parser.add_argument("--ndim", default=4, type=int,
                    help="Number of grid dimensions")

def make_mlp(l, act=T.nn.LeakyReLU(), tail=None):
    """Build a fully connected network as a ``T.nn.Sequential``.

    Args:
        l: Sequence of layer sizes ``[in, hidden..., out]``; one
            ``T.nn.Linear`` is created per consecutive pair.
        act: Activation module inserted after every linear layer except the
            last one. The same module instance is reused at every position.
        tail: Optional iterable of extra modules appended after the final
            linear layer. ``None`` (the default) appends nothing.

    Returns:
        The assembled ``T.nn.Sequential``.
    """
    layers = []
    last_hidden = len(l) - 2  # index of the final Linear, which gets no activation
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(T.nn.Linear(i, o))
        if n < last_hidden:
            layers.append(act)
    # A None sentinel replaces the former mutable ``[]`` default.
    if tail is not None:
        layers.extend(tail)
    return T.nn.Sequential(*layers)

def main(args):
    """Train a GFlowNet sampler on the hypergrid and pickle the run.

    States are integer vectors of length ``args.ndim`` with entries in
    ``0..args.horizon-1``. Every trajectory starts at the all-zero state; an
    action either increments one coordinate or terminates. The model is
    trained with either the trajectory balance (TB) or the detailed balance
    (DB) objective, and every 100 iterations the empirical distribution of
    terminal states is compared with the true reward distribution.

    Args:
        args: Parsed command-line namespace (see the module-level ``parser``).

    Side effects:
        Prints progress and writes a gzip-compressed pickle holding losses,
        Z values, parameters, visited states and metric logs to
        ``args.save_path`` (the parent directory is created if missing).
    """
    import os  # local import: only needed to prepare the output directory

    device = args.device
    horizon = args.horizon
    ndim = args.ndim
    n_hid = args.n_hid
    n_layers = args.n_layers
    bs = args.mbsize
    # Accept "detailed balance", "detailed_balance" and "detailed-balance";
    # previously only the spelling with a space was recognised, so the
    # underscore form (the convention of the default value) silently ran TB.
    loss_name = args.loss.strip().lower().replace('_', ' ').replace('-', ' ')
    detailed_balance = (loss_name == "detailed balance")
    uniform_pb = (args.pb_type=="uniform")
    r = args.r

    # Create the output directory up front so that a missing directory does
    # not make the final save fail after the whole training run.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print('loss is', 'DB' if detailed_balance else 'TB')
    print('Uniform pb: ', uniform_pb)

    def log_reward(x):
        """Return the log reward of states ``x`` (last axis = coordinates)."""
        # Map coordinates to |.| in [0, 1], measured from the grid centre.
        ax = abs(x / (horizon-1) * 2 - 1)
        # 0.5 where all coordinates have ax > 0.5, plus 2 where all satisfy
        # 0.6 < ax < 0.8, plus the base reward r everywhere.
        return ((ax > 0.5).prod(-1) * 0.5 + ((ax < 0.8) * (ax > 0.6)).prod(-1) * 2 + r).log()

    # Enumerate all horizon**ndim states. Coordinate i varies along grid axis
    # ndim-1-i, so the flattened index equals sum_i x_i * horizon**i, the same
    # encoding used for `state` in the training loop below.
    j = T.zeros((horizon,)*ndim+(ndim,))

    for i in range(ndim):
        jj = T.linspace(0,horizon-1,horizon)
        for _ in range(i): jj = jj.unsqueeze(1)
        j[...,i] = jj

    truelr = log_reward(j)
    print('total reward', truelr.view(-1).logsumexp(0))
    # Target distribution over terminal states: rewards normalised to sum 1.
    true_dist = truelr.flatten().softmax(0).cpu().numpy()

    def toin(z):
        """One-hot encode each coordinate and concatenate into one vector."""
        return T.nn.functional.one_hot(z,horizon).view(z.shape[0],-1).float()

    # Z holds the log partition function estimate: a learned parameter under
    # TB, overwritten with the model's initial-state log flow under DB.
    Z = T.zeros((1,)).to(device)
    if detailed_balance:
        # Outputs: ndim+1 forward logits, ndim backward logits, 1 log flow.
        model = make_mlp([ndim*horizon] + [n_hid] * n_layers + [2*ndim+2]).to(device)
        opt = T.optim.Adam([ {'params':model.parameters(), 'lr':args.learning_rate_model} ])
    else:
        # Outputs: ndim+1 forward logits, ndim backward logits.
        model = make_mlp([ndim*horizon] + [n_hid] * n_layers + [2*ndim+1]).to(device)
        opt = T.optim.Adam([ {'params':model.parameters(), 'lr':args.learning_rate_model}, {'params':[Z], 'lr':args.learning_rate_z} ])
        Z.requires_grad_()

    losses = []
    zs = []
    all_visited_state = []  # flattened index of every sampled terminal state
    all_visited = []        # coordinates of every sampled terminal state
    # Iteration at which each state was first sampled (-1 = never); tracked
    # during training but not included in the saved results.
    first_visit = -1 * np.ones_like(true_dist)
    l1log = []
    kllog = []

    # Loop-invariant sampling settings.
    exp_weight = args.exp_weight
    temp = 1

    for it in tqdm.trange(args.n_train_steps):
        opt.zero_grad()

        z = T.zeros((bs,ndim), dtype=T.long).to(device)
        done = T.full((bs,), False, dtype=T.bool).to(device)

        action = None

        if detailed_balance:
            # One residual per (step, trajectory); ndim*horizon rows are
            # enough for the longest possible trajectory.
            ll_diff = T.zeros((ndim*horizon, bs)).to(device)
        else:
            # One residual per trajectory, starting from log Z.
            ll_diff = T.zeros((bs,)).to(device)
            ll_diff += Z

        i = 0
        while T.any(~done):
            pred = model(toin(z[~done]))

            # Forward mask: a coordinate already at horizon-1 cannot be
            # incremented; the stop action (last column) is always allowed.
            edge_mask = T.cat([ (z[~done]==horizon-1).float(), T.zeros(((~done).sum(),1), device=device) ], 1)
            logits = (pred[...,:ndim+1] - 1000000000*edge_mask).log_softmax(1)

            # Backward mask: a coordinate at 0 cannot have been incremented.
            # With uniform_pb the learned logits are zeroed out, leaving a
            # uniform distribution over the unmasked parents.
            init_edge_mask = (z[~done]== 0).float()
            back_logits = ( (0 if uniform_pb else 1)*pred[...,ndim+1:2*ndim+1] - 1000000000*init_edge_mask).log_softmax(1)

            if detailed_balance:
                log_flow = pred[...,2*ndim+1]
                # log F(s_i) enters positively at step i and negatively (as
                # the child flow) at step i-1.
                ll_diff[i,~done] += log_flow
                if i>0: ll_diff[i-1,~done] -= log_flow
                else: Z[:] = log_flow[0].item()

            if action is not None:
                # Subtract log P_B of the move that led to the current state;
                # `action` still refers to the previous step, filtered to the
                # trajectories that did not terminate there.
                if detailed_balance:
                    ll_diff[i-1,~done] -= back_logits.gather(1, action[action!=ndim].unsqueeze(1)).squeeze(1)
                else:
                    ll_diff[~done] -= back_logits.gather(1, action[action!=ndim].unsqueeze(1)).squeeze(1)

            # Behaviour policy: the forward policy mixed with a uniform
            # distribution over valid actions, weighted by exp_weight.
            sample_ins_probs = (1-exp_weight)*(logits/temp).softmax(1) + exp_weight*(1-edge_mask) / (1-edge_mask+0.0000001).sum(1).unsqueeze(1)

            action = sample_ins_probs.multinomial(1)
            if detailed_balance:
                ll_diff[i,~done] += logits.gather(1, action).squeeze(1)
            else:
                ll_diff[~done] += logits.gather(1, action).squeeze(1)

            terminate = (action==ndim).squeeze(1)
            for x in z[~done][terminate]:
                state = (x.cpu()*(horizon**T.arange(ndim))).sum().item()
                if first_visit[state]<0: first_visit[state] = it
                all_visited_state.append(state)
                all_visited.append(list(x.cpu().numpy()))

            if detailed_balance:
                # Terminating transitions are balanced against the reward.
                termination_mask = ~done
                termination_mask[~done] &= terminate
                ll_diff[i,termination_mask] -= log_reward(z[~done][terminate].float())
            done[~done] |= terminate

            with T.no_grad():
                # Apply the chosen increments to the still-active states.
                z[~done] = z[~done].scatter_add(1, action[~terminate], T.ones(action[~terminate].shape, dtype=T.long, device=device))

            i += 1

        # Transitions per trajectory: one per increment plus the stop action.
        lens = z.sum(1)+1
        if not detailed_balance:
            lr = log_reward(z.float())
            ll_diff -= lr

        # Mean squared residual: per transition for DB, per trajectory for TB.
        loss = (ll_diff**2).sum()/(lens.sum() if detailed_balance else bs)

        loss.backward()

        opt.step()

        losses.append(loss.item())

        zs.append(Z.item())

        if it%100==0:
            print('loss =', np.array(losses[-100:]).mean(), 'Z =', Z.item())
            # Empirical distribution over the most recent terminal states.
            emp_dist = np.bincount(all_visited_state[-args.num_empirical_loss:], minlength=len(true_dist)).astype(float)
            emp_dist /= emp_dist.sum()
            l1 = np.abs(true_dist-emp_dist).mean()
            kl = -(true_dist * np.log((emp_dist / true_dist)+1e-13)).sum()
            print('L1 =', l1)
            print('KL =', kl)
            l1log.append((len(all_visited), l1))
            kllog.append((len(all_visited), kl))

    results = {
        'losses': np.float32(losses),
        'zs': np.float32(zs),
        'params': [p.data.to('cpu').numpy() for p in model.parameters()],
        # NOTE: stored as int8, so coordinates above 127 would not fit.
        'visited': np.int8(all_visited),
        'emp_dist_loss': l1log,
        'kl': kllog,
        'state_dict': model.state_dict(),
        'args': args,
    }
    # The context manager closes (and flushes) the gzip stream; the handle
    # was previously left open.
    with gzip.open(args.save_path, 'wb') as out_file:
        pickle.dump(results, out_file)

if __name__ == '__main__':
    # Script entry point: read the command line, seed the torch and numpy
    # random generators with the same value, then run training.
    cli_args = parser.parse_args()

    for seed_rng in (T.manual_seed, np.random.seed):
        seed_rng(cli_args.seed)
    # T.set_num_threads(1)

    main(cli_args)
